#pragma once

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#include "Log.hpp"
#include "Common.hpp"
#include "InetAddr.hpp"

using namespace LogModule;

// Sentinel descriptor value meaning "no socket has been created/attached yet".
constexpr int gdefaultsockfd = -1;
// Pending-connection queue length handed to listen().
constexpr int gbacklog = 8;

namespace SocketModule
{
    class Socket;
    using SockPtr = std::shared_ptr<Socket>;

    // 模版方法模式
    // 基类: 规定创建Socket方法
    class Socket
    {
    public:
        virtual ~Socket() = default;
        virtual void SocketOrDie() = 0;
        virtual void SetSocketOpt() = 0;
        virtual bool BindOrDie(int port) = 0;
        virtual bool ListenOrDie() = 0;
        virtual SockPtr AcceptOrDie(InetAddr *client) = 0;
        virtual void Close() = 0;
        virtual int Recv(std::string *out) = 0;
        virtual int Send(const std::string &in) = 0;
        virtual int Fd() = 0;

        // 提供创建TCP套接字的固定格式
        void BuildTcpSocketMethod(int port)
        {
            SocketOrDie();
            SetSocketOpt();
            BindOrDie(port);
            ListenOrDie();
        }
    };

    class TcpSocket : public Socket
    {
    public:
        TcpSocket(int sockfd = gdefaultsockfd)
            : _sockfd(sockfd)
        {}

        virtual ~TcpSocket() {}

        virtual void SocketOrDie() override
        {
            _sockfd = ::socket(AF_INET, SOCK_STREAM, 0);
            if (_sockfd < 0)
            {
                LOG(LogLevel::DEBUG) << "socket error";
                exit(SOCKET_ERR);
            }
            LOG(LogLevel::DEBUG) << "socket success, sockfd: " << _sockfd;
        }

        virtual void SetSocketOpt() override
        {
            // 保证服务器在异常断开之后可以立即重启, 不会存在bind error问题!
            int opt = 1;
            ::setsockopt(_sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
        }

        virtual bool BindOrDie(int port) override
        {
            if (_sockfd == gdefaultsockfd)
                return false;
            InetAddr addr(port);
            int n = ::bind(_sockfd, addr.NetAddr(), addr.NetAddrLen());
            if (n < 0)
            {
                LOG(LogLevel::DEBUG) << "bind error";
                exit(BIND_ERR);
            }
            LOG(LogLevel::DEBUG) << "bind success, sockfd: " << _sockfd;
            return true;
        }

        virtual bool ListenOrDie() override
        {
            if (_sockfd == gdefaultsockfd)
                return false;
            int n = ::listen(_sockfd, gbacklog);
            if (n < 0)
            {
                LOG(LogLevel::DEBUG) << "listen error";
                exit(LISTEN_ERR);
            }
            LOG(LogLevel::DEBUG) << "listen success, sockfd: " << _sockfd;
            return true;
        }

        // 返回: 文件描述符 && 客户端信息
        virtual SockPtr AcceptOrDie(InetAddr *client) override
        {
            struct sockaddr_in peer;
            socklen_t len = sizeof(peer);
            int newsockfd = ::accept(_sockfd, CONV(&peer), &len);
            if (newsockfd < 0)
            {
                LOG(LogLevel::DEBUG) << "accept error";
                return nullptr;
            }
            client->SetAddr(peer);
            return std::make_shared<TcpSocket>(newsockfd);
        }

        virtual void Close() override
        {
            if (_sockfd == gdefaultsockfd)
                return;
            ::close(_sockfd);
        }

        virtual int Recv(std::string *out) override
        {
            char buffer[1024 * 8];
            int n = ::recv(_sockfd, buffer, sizeof(buffer) - 1, 0);
            if(n > 0)
            {
                buffer[n] = 0;
                *out = buffer;
            }
            return n;
        }

        virtual int Send(const std::string &in) override
        {
            int n = ::send(_sockfd, in.c_str(), in.size(), 0);
            return n;
        }

        virtual int Fd() override
        {
            return _sockfd;
        }

    private:
        int _sockfd;
    };

    // int main()
    // {
    //     Socket *sk = new TcpSocket();
    //     sk->BuildTcpSocket(8080);
    //
    //     return 0;
    // }
}